// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>

// Map the backend-specific tolerance macros, when they are defined, onto the generic
// DEFAULT_*_TOLERANCE_BITS names. `${BACKEND_NAME}` is a template placeholder token
// (presumably substituted before compilation — confirm against the build scripts).
// NOTE(review): the DEFAULT_* macros are presumably read by the test utility headers
// included further down, which is why this section precedes them — confirm.
// clang-format off
#ifdef ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
#define DEFAULT_FLOAT_TOLERANCE_BITS ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
#endif
#ifdef ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
#define DEFAULT_DOUBLE_TOLERANCE_BITS ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
#endif
// clang-format on

#include "onnx_import/onnx.hpp"
#include "default_opset.hpp"
#include "engines_util/test_case.hpp"
#include "engines_util/test_engines.hpp"
#include "util/test_control.hpp"

// NOTE(review): presumably opens a region in which deprecation warnings are silenced; the
// matching end marker is not visible in this part of the file — confirm.
NGRAPH_SUPPRESS_DEPRECATED_START

using namespace ngraph;

// Holds the "${MANIFEST}" placeholder text. It is not referenced by name in the visible tests;
// presumably it is picked up by the NGRAPH_TEST macro — confirm.
static std::string s_manifest = "${MANIFEST}";
// Device name obtained from the "${BACKEND_NAME}" placeholder; handed to test::TestCase by the tests.
static std::string s_device = test::backend_name_to_device("${BACKEND_NAME}");

// Loads "onnx/com.microsoft/bias_gelu.onnx", feeds a 10-value and a 5-value float input,
// and compares the single float output with 10 reference values using the default run().
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_bias_gelu) {
    const auto model_path = file_util::path_join(SERIALIZED_ZOO, "onnx/com.microsoft/bias_gelu.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);

    // First model input (10 values).
    const std::vector<float> first_input{0.5488135,
                                         0.71518934,
                                         0.60276335,
                                         0.5448832,
                                         0.4236548,
                                         0.6458941,
                                         0.4375872,
                                         0.891773,
                                         0.96366274,
                                         0.3834415};
    // Second model input (5 values).
    const std::vector<float> second_input{0.79172504, 0.5288949, 0.56804454, 0.92559665, 0.07103606};
    // Reference result the model output is compared against.
    const std::vector<float> reference{
        1.2198428, 1.1112978, 1.0293297, 1.366493, 0.3411342, 1.329408, 0.8051748, 1.354462, 1.8336612, 0.3068893};

    auto tc = test::TestCase(model, s_device);
    tc.add_input<float>(first_input);
    tc.add_input<float>(second_input);
    tc.add_expected_output<float>(reference);
    tc.run();
}

// Loads "skip_layer_normalization_with_gamma_beta_bias.onnx", supplies two 24-value float
// inputs (input and skip) and checks the output against 24 reference values with the
// default floating-point tolerance.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_skip_layer_normalization_with_gamma_beta_bias) {
    const auto model_path =
        file_util::path_join(SERIALIZED_ZOO, "onnx/com.microsoft/skip_layer_normalization_with_gamma_beta_bias.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);
    auto tc = test::TestCase(model, s_device);

    const std::vector<float> input_values{
        0.54881352, 0.71518934, 0.60276335, 0.54488319, 0.42365479, 0.64589411, 0.43758720, 0.89177299,
        0.96366274, 0.38344151, 0.79172504, 0.52889490, 0.56804454, 0.92559665, 0.07103606, 0.08712930,
        0.02021840, 0.83261985, 0.77815676, 0.87001216, 0.97861832, 0.79915857, 0.46147937, 0.78052920,
    };
    tc.add_input<float>(input_values);

    const std::vector<float> skip_values{
        0.11827443, 0.63992101, 0.14335328, 0.94466889, 0.52184832, 0.41466194, 0.26455560, 0.77423370,
        0.45615032, 0.56843394, 0.01878980, 0.61763549, 0.61209571, 0.61693400, 0.94374806, 0.68182027,
        0.35950789, 0.43703195, 0.69763118, 0.06022547, 0.66676670, 0.67063785, 0.21038257, 0.12892629,
    };
    tc.add_input<float>(skip_values);

    const std::vector<float> reference{
        -0.19721794, -0.42944565, 0.18620640, 0.61282152,  -0.11097327, -0.59518522, 0.13393641,  0.66901535,
        0.04256713,  -0.71902490, 0.23107991, 0.17300847,  -0.04390603, -0.31109563, 0.51021838,  -0.66914201,
        -0.20009395, -0.43313017, 0.67281967, -0.01712347, 0.09767530,  -0.43024653, -0.01836969, -0.29238200,
    };
    tc.add_expected_output<float>(reference);

    tc.run_with_tolerance_as_fp();
}

// Loads "skip_layer_normalization_with_gamma_beta.onnx", supplies two 24-value float inputs
// (input and skip) and checks the output against 24 reference values with the default
// floating-point tolerance.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_skip_layer_normalization_with_gamma_beta) {
    const auto model_path =
        file_util::path_join(SERIALIZED_ZOO, "onnx/com.microsoft/skip_layer_normalization_with_gamma_beta.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);
    auto tc = test::TestCase(model, s_device);

    const std::vector<float> input_values{
        0.54881352, 0.71518934, 0.60276335, 0.54488319, 0.42365479, 0.64589411, 0.43758720, 0.89177299,
        0.96366274, 0.38344151, 0.79172504, 0.52889490, 0.56804454, 0.92559665, 0.07103606, 0.08712930,
        0.02021840, 0.83261985, 0.77815676, 0.87001216, 0.97861832, 0.79915857, 0.46147937, 0.78052920,
    };
    tc.add_input<float>(input_values);

    const std::vector<float> skip_values{
        0.11827443, 0.63992101, 0.14335328, 0.94466889, 0.52184832, 0.41466194, 0.26455560, 0.77423370,
        0.45615032, 0.56843394, 0.01878980, 0.61763549, 0.61209571, 0.61693400, 0.94374806, 0.68182027,
        0.35950789, 0.43703195, 0.69763118, 0.06022547, 0.66676670, 0.67063785, 0.21038257, 0.12892629,
    };
    tc.add_input<float>(skip_values);

    const std::vector<float> reference{
        -0.17974678, -0.23946194, -0.04376268, 0.46959469,  -0.11171167, -0.41859278, -0.11082965, 0.64513868,
        0.07773457,  -0.51403606, -0.13661698, 0.11262375,  -0.05096011, -0.10416907, 0.10070466,  -0.50876135,
        -0.22290939, -0.27663514, 0.55416691,  -0.08064821, 0.04857478,  -0.25121087, -0.15912610, -0.26637587,
    };
    tc.add_expected_output<float>(reference);

    tc.run_with_tolerance_as_fp();
}

// Loads "skip_layer_normalization_with_gamma.onnx", supplies two 24-value float inputs
// (input and skip) and checks the output against 24 reference values with the default
// floating-point tolerance.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_skip_layer_normalization_with_gamma) {
    const auto model_path =
        file_util::path_join(SERIALIZED_ZOO, "onnx/com.microsoft/skip_layer_normalization_with_gamma.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);
    auto tc = test::TestCase(model, s_device);

    const std::vector<float> input_values{
        0.54881352, 0.71518934, 0.60276335, 0.54488319, 0.42365479, 0.64589411, 0.43758720, 0.89177299,
        0.96366274, 0.38344151, 0.79172504, 0.52889490, 0.56804454, 0.92559665, 0.07103606, 0.08712930,
        0.02021840, 0.83261985, 0.77815676, 0.87001216, 0.97861832, 0.79915857, 0.46147937, 0.78052920,
    };
    tc.add_input<float>(input_values);

    const std::vector<float> skip_values{
        0.11827443, 0.63992101, 0.14335328, 0.94466889, 0.52184832, 0.41466194, 0.26455560, 0.77423370,
        0.45615032, 0.56843394, 0.01878980, 0.61763549, 0.61209571, 0.61693400, 0.94374806, 0.68182027,
        0.35950789, 0.43703195, 0.69763118, 0.06022547, 0.66676670, 0.67063785, 0.21038257, 0.12892629,
    };
    tc.add_input<float>(skip_values);

    const std::vector<float> reference{
        -0.10974677, 0.16053806,  -0.26376268, 0.46959469,  -0.04171166, -0.01859277, -0.33082965, 0.64513868,
        0.14773457,  -0.11403608, -0.35661697, 0.11262375,  0.01903989,  0.29583094,  -0.11929534, -0.50876135,
        -0.15290938, 0.12336487,  0.33416691,  -0.08064821, 0.11857478,  0.14878914,  -0.37912610, -0.26637587,
    };
    tc.add_expected_output<float>(reference);

    tc.run_with_tolerance_as_fp();
}

// Loads "skip_layer_normalization_dynamic_shapes.onnx" and passes every tensor with an explicit
// shape: input and skip as {3, 2, 4}, gamma, beta and bias as {4}. The output is expected to be
// {3, 2, 4} and is compared with the default floating-point tolerance.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_skip_layer_normalization_dynamic_shapes) {
    const auto model_path =
        file_util::path_join(SERIALIZED_ZOO, "onnx/com.microsoft/skip_layer_normalization_dynamic_shapes.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);
    auto tc = test::TestCase(model, s_device);

    // Inputs are registered in the same order as their declarations below.
    const std::vector<float> input_values{
        0.54881352, 0.71518934, 0.60276335, 0.54488319, 0.42365479, 0.64589411, 0.43758720, 0.89177299,
        0.96366274, 0.38344151, 0.79172504, 0.52889490, 0.56804454, 0.92559665, 0.07103606, 0.08712930,
        0.02021840, 0.83261985, 0.77815676, 0.87001216, 0.97861832, 0.79915857, 0.46147937, 0.78052920,
    };
    tc.add_input<float>(Shape{3, 2, 4}, input_values);

    const std::vector<float> skip_values{
        0.11827443, 0.63992101, 0.14335328, 0.94466889, 0.52184832, 0.41466194, 0.26455560, 0.77423370,
        0.45615032, 0.56843394, 0.01878980, 0.61763549, 0.61209571, 0.61693400, 0.94374806, 0.68182027,
        0.35950789, 0.43703195, 0.69763118, 0.06022547, 0.66676670, 0.67063785, 0.21038257, 0.12892629,
    };
    tc.add_input<float>(Shape{3, 2, 4}, skip_values);

    const std::vector<float> gamma_values{
        0.31542835,
        0.36371076,
        0.57019675,
        0.43860152,
    };
    tc.add_input<float>(Shape{4}, gamma_values);

    const std::vector<float> beta_values{
        0.98837382,
        0.10204481,
        0.20887676,
        0.16130951,
    };
    tc.add_input<float>(Shape{4}, beta_values);

    const std::vector<float> bias_values{
        0.65310830,
        0.25329161,
        0.46631077,
        0.24442559,
    };
    tc.add_input<float>(Shape{4}, bias_values);

    const std::vector<float> reference{
        0.76600611, 0.34308332,  -0.48470584, 0.71335256,  1.10028172, -0.13354334, -0.45232186, 0.79840088,
        1.52454257, -0.19450217, -0.13759643, 0.03988872,  1.27861762, 0.39529073,  0.12247884,  -0.52944231,
        0.64228040, 0.21059875,  1.05966032,  -0.14278713, 1.46366918, 0.21215858,  -0.31640187, -0.22832340,
    };
    tc.add_expected_output<float>(Shape{3, 2, 4}, reference);

    tc.run_with_tolerance_as_fp();
}

// Loads "embed_layer_normalization.onnx", feeds 24 integer ids as the only input and compares
// the float output with 120 reference values at a tolerance of 1e-7f.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_embed_layer_normalization) {
    const auto model_path =
        file_util::path_join(SERIALIZED_ZOO, "onnx/com.microsoft/embed_layer_normalization.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);
    auto tc = test::TestCase(model, s_device);

    const std::vector<int> ids{
        8, 1, 5, 9, 8, 9, 4, 3, 0, 3, 5, 0, 2, 3, 8, 1, 3, 3, 3, 7, 0, 1, 9, 9,
    };
    tc.add_input<int>(ids);

    const std::vector<float> reference{
        -0.06615843, -0.18040463, 0.02199928,  0.01868065,  0.05397778,  -0.11761580, -0.09138932, -0.02506775,
        -0.02368510, -0.10373901, -0.05551499, -0.20972314, 0.01365213,  0.01132561,  -0.08603337, -0.08906764,
        0.09692993,  -0.04444099, -0.02037602, -0.03453060, -0.10214549, -0.13331436, -0.02665862, -0.01228805,
        -0.14232540, -0.07032782, 0.05511986,  -0.00120272, -0.04875736, -0.13051267, -0.05709254, 0.17854357,
        -0.01759873, -0.01819968, 0.07573269,  0.00557164,  0.06232717,  0.00530490,  -0.01565807, -0.14841977,
        -0.02299280, 0.02038561,  -0.00049481, 0.02575402,  0.10081697,  -0.12517214, -0.09316762, -0.00974943,
        -0.03093284, -0.06309240, -0.05551499, -0.20972314, 0.01365213,  0.01132561,  -0.08603337, -0.06176658,
        0.08304203,  -0.05025182, 0.00383657,  -0.02288112, -0.11407227, -0.01386134, -0.04411830, -0.00537948,
        0.00164397,  -0.03739140, 0.09941526,  0.00333974,  -0.04251949, -0.12992151, -0.09509478, -0.11811313,
        -0.03307065, -0.00866115, -0.15162414, 0.01106802,  0.06037656,  0.00035292,  -0.00223284, -0.11215645,
        -0.01390734, 0.07064321,  0.04028325,  -0.00290875, 0.12875907,  -0.12517214, -0.09316762, -0.00974943,
        -0.03093284, -0.06309240, -0.08723789, 0.03130914,  0.03131931,  -0.01526242, 0.20811458,  -0.05696163,
        0.16304255,  -0.02407495, -0.02955675, -0.03086288, -0.08130091, -0.05001551, -0.04875683, 0.00143666,
        -0.12153473, -0.00018507, 0.10957482,  -0.00416618, -0.01612359, -0.11605026, -0.08593204, 0.09055272,
        -0.03054028, -0.03603891, -0.08479506, -0.00034568, 0.03713699,  0.00163411,  -0.01738501, -0.18267182,
    };
    tc.add_expected_output<float>(reference);

    tc.run_with_tolerance_as_fp(1e-7f);
}

// Loads "embed_layer_normalization_with_segment_embedding.onnx", feeds 24 input ids and 24
// segment ids, and checks two outputs: 120 float values and a 3-element int mask index that is
// expected to be all zeros. Comparison tolerance is 1e-7.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_embed_layer_normalization_with_segment_embedding) {
    const auto model_path =
        file_util::path_join(SERIALIZED_ZOO,
                             "onnx/com.microsoft/embed_layer_normalization_with_segment_embedding.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);
    auto tc = test::TestCase(model, s_device);

    const std::vector<int> ids{
        8, 1, 5, 9, 8, 9, 4, 3, 0, 3, 5, 0, 2, 3, 8, 1, 3, 3, 3, 7, 0, 1, 9, 9,
    };
    tc.add_input<int>(ids);

    const std::vector<int> segments{
        0, 2, 0, 2, 2, 0, 2, 0, 0, 0, 1, 1, 2, 0, 0, 1, 0, 1, 2, 2, 0, 1, 1, 1,
    };
    tc.add_input<int>(segments);

    // First expected output: float values.
    const std::vector<float> reference{
        -0.06044213, -0.14845914, 0.02457689,  0.02091519,  0.09514004,  -0.10280035, -0.02087995, -0.03323204,
        -0.02967127, -0.13447416, -0.05191760, -0.16518904, 0.02340531,  0.02176395,  0.04972410,  -0.07360736,
        0.12192874,  -0.04081530, -0.02338044, -0.05671440, -0.09475864, -0.08944942, -0.03362993, -0.01683486,
        -0.16770349, -0.07382569, 0.06230322,  0.02215859,  -0.05212611, -0.03934773, -0.04748865, 0.18134241,
        -0.01965741, -0.02202452, 0.01973994,  0.01575558,  0.04300199,  0.01436110,  -0.00198062, -0.09065692,
        -0.02923042, -0.00748686, 0.00717049,  0.02638642,  0.12174864,  -0.12973398, -0.11872391, -0.00549398,
        -0.02386289, -0.02210563, -0.03590920, -0.13728066, -0.01337939, 0.01538021,  -0.14687485, -0.05033565,
        0.03818212,  -0.04939338, 0.00961064,  -0.07407621, -0.09624685, 0.05594898,  -0.04948713, -0.01305631,
        -0.03779668, -0.01469170, 0.12346989,  0.02082030,  -0.03449103, -0.06029151, -0.09300473, -0.16308543,
        -0.02370042, 0.01066893,  -0.06523034, 0.00497636,  0.01933458,  -0.00900802, 0.00430878,  -0.13999483,
        -0.02377289, 0.01760014,  0.03896973,  0.00831112,  0.15634246,  -0.11109130, -0.11997811, -0.02304414,
        -0.01989413, -0.12763791, -0.05698400, 0.17125534,  0.00499324,  -0.02953288, 0.09178342,  -0.05001877,
        0.16157132,  -0.02312993, -0.02932195, -0.04914058, -0.07994118, -0.07199102, -0.04517454, 0.01249476,
        -0.07525793, -0.00207180, 0.03993115,  -0.01676321, -0.00214832, -0.16074482, -0.05012497, -0.00552153,
        -0.04302063, -0.00549224, -0.18399858, -0.00767871, -0.02209404, -0.01383207, -0.00082931, -0.19533031,
    };
    tc.add_expected_output<float>(reference);

    // Second expected output: mask index (no mask input is supplied in this test).
    const std::vector<int> reference_mask_index{
        0,
        0,
        0,
    };
    tc.add_expected_output<int>(reference_mask_index);

    tc.run_with_tolerance_as_fp(1e-7);
}

// Loads "embed_layer_normalization_with_segment_embedding_and_mask.onnx", feeds 24 input ids,
// 24 segment ids and a 24-value 0/1 mask, and checks two outputs: 120 float values and a
// 3-element int mask index. Comparison tolerance is 1e-7.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_embed_layer_normalization_with_segment_embedding_and_mask) {
    const auto model_path =
        file_util::path_join(SERIALIZED_ZOO,
                             "onnx/com.microsoft/embed_layer_normalization_with_segment_embedding_and_mask.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);
    auto tc = test::TestCase(model, s_device);

    const std::vector<int> ids{
        8, 1, 5, 9, 8, 9, 4, 3, 0, 3, 5, 0, 2, 3, 8, 1, 3, 3, 3, 7, 0, 1, 9, 9,
    };
    tc.add_input<int>(ids);

    const std::vector<int> segments{
        0, 2, 0, 2, 2, 0, 2, 0, 0, 0, 1, 1, 2, 0, 0, 1, 0, 1, 2, 2, 0, 1, 1, 1,
    };
    tc.add_input<int>(segments);

    const std::vector<int> mask_values{
        1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1,
    };
    tc.add_input<int>(mask_values);

    // First expected output: float values.
    const std::vector<float> reference{
        -0.06044213, -0.14845914, 0.02457689,  0.02091519,  0.09514004,  -0.10280035, -0.02087995, -0.03323204,
        -0.02967127, -0.13447416, -0.05191760, -0.16518904, 0.02340531,  0.02176395,  0.04972410,  -0.07360736,
        0.12192874,  -0.04081530, -0.02338044, -0.05671440, -0.09475864, -0.08944942, -0.03362993, -0.01683486,
        -0.16770349, -0.07382569, 0.06230322,  0.02215859,  -0.05212611, -0.03934773, -0.04748865, 0.18134241,
        -0.01965741, -0.02202452, 0.01973994,  0.01575558,  0.04300199,  0.01436110,  -0.00198062, -0.09065692,
        -0.02923042, -0.00748686, 0.00717049,  0.02638642,  0.12174864,  -0.12973398, -0.11872391, -0.00549398,
        -0.02386289, -0.02210563, -0.03590920, -0.13728066, -0.01337939, 0.01538021,  -0.14687485, -0.05033565,
        0.03818212,  -0.04939338, 0.00961064,  -0.07407621, -0.09624685, 0.05594898,  -0.04948713, -0.01305631,
        -0.03779668, -0.01469170, 0.12346989,  0.02082030,  -0.03449103, -0.06029151, -0.09300473, -0.16308543,
        -0.02370042, 0.01066893,  -0.06523034, 0.00497636,  0.01933458,  -0.00900802, 0.00430878,  -0.13999483,
        -0.02377289, 0.01760014,  0.03896973,  0.00831112,  0.15634246,  -0.11109130, -0.11997811, -0.02304414,
        -0.01989413, -0.12763791, -0.05698400, 0.17125534,  0.00499324,  -0.02953288, 0.09178342,  -0.05001877,
        0.16157132,  -0.02312993, -0.02932195, -0.04914058, -0.07994118, -0.07199102, -0.04517454, 0.01249476,
        -0.07525793, -0.00207180, 0.03993115,  -0.01676321, -0.00214832, -0.16074482, -0.05012497, -0.00552153,
        -0.04302063, -0.00549224, -0.18399858, -0.00767871, -0.02209404, -0.01383207, -0.00082931, -0.19533031,
    };
    tc.add_expected_output<float>(reference);

    // Second expected output: mask index. These values equal the number of ones in each
    // consecutive group of 8 mask entries (presumably one group per batch row — confirm).
    const std::vector<int> reference_mask_index{
        5,
        3,
        4,
    };
    tc.add_expected_output<int>(reference_mask_index);

    tc.run_with_tolerance_as_fp(1e-7);
}

// Loads "embed_layer_normalization_dynamic_shapes.onnx" and passes every tensor with an explicit
// shape: ids, segment ids and mask as {3, 8}; word, position and segment embeddings as {10, 5},
// {8, 5} and {3, 5}; gamma and beta as {5}. Expects a {3, 8, 5} float output and a {3} int mask
// index, compared at a tolerance of 1e-6.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_embed_layer_normalization_dynamic_shapes) {
    const auto model_path =
        file_util::path_join(SERIALIZED_ZOO, "onnx/com.microsoft/embed_layer_normalization_dynamic_shapes.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);
    auto tc = test::TestCase(model, s_device);

    // Inputs are registered in this order: ids, segment ids, word/position/segment embeddings,
    // gamma, beta, mask.
    const std::vector<int> ids{
        8, 1, 5, 9, 8, 9, 4, 3, 0, 3, 5, 0, 2, 3, 8, 1, 3, 3, 3, 7, 0, 1, 9, 9,
    };
    tc.add_input<int>(Shape{3, 8}, ids);

    const std::vector<int> segments{
        0, 2, 0, 2, 2, 0, 2, 0, 0, 0, 1, 1, 2, 0, 0, 1, 0, 1, 2, 2, 0, 1, 1, 1,
    };
    tc.add_input<int>(Shape{3, 8}, segments);

    const std::vector<float> word_table{
        0.96980906, 0.65314001, 0.17090958, 0.35815218, 0.75068617, 0.60783064, 0.32504722, 0.03842543, 0.63427407,
        0.95894927, 0.65279031, 0.63505888, 0.99529958, 0.58185035, 0.41436860, 0.47469750, 0.62351012, 0.33800763,
        0.67475230, 0.31720173, 0.77834547, 0.94957107, 0.66252685, 0.01357164, 0.62284607, 0.67365962, 0.97194499,
        0.87819350, 0.50962436, 0.05571469, 0.45115921, 0.01998767, 0.44171092, 0.97958672, 0.35944447, 0.48089352,
        0.68866116, 0.88047588, 0.91823548, 0.21682213, 0.56518888, 0.86510259, 0.50896895, 0.91672295, 0.92115760,
        0.08311249, 0.27771857, 0.00935670, 0.84234208, 0.64717412,
    };
    tc.add_input<float>(Shape{10, 5}, word_table);

    const std::vector<float> position_table{
        0.84138614, 0.26473016, 0.39782074, 0.55282146, 0.16494046, 0.36980811, 0.14644176, 0.56961840,
        0.70373726, 0.28847644, 0.43328807, 0.75610667, 0.39609829, 0.89603841, 0.63892108, 0.89155442,
        0.68005556, 0.44919774, 0.97857094, 0.11620191, 0.76702368, 0.41182014, 0.67543906, 0.24979627,
        0.31321833, 0.96541619, 0.58846509, 0.65966839, 0.53320622, 0.23053302, 0.39486930, 0.61880857,
        0.47486752, 0.47013220, 0.71607453, 0.28799102, 0.38346222, 0.74916983, 0.87845218, 0.10286336,
    };
    tc.add_input<float>(Shape{8, 5}, position_table);

    const std::vector<float> segment_table{
        0.09237389,
        0.35404667,
        0.55181628,
        0.03362509,
        0.96896178,
        0.32099724,
        0.22126268,
        0.14126390,
        0.09725992,
        0.98404223,
        0.26034093,
        0.53702253,
        0.44792616,
        0.09956909,
        0.35231167,
    };
    tc.add_input<float>(Shape{3, 5}, segment_table);

    const std::vector<float> gamma_values{
        0.46924916,
        0.84114015,
        0.90464777,
        0.03755938,
        0.50831544,
    };
    tc.add_input<float>(Shape{5}, gamma_values);

    const std::vector<float> beta_values{
        0.16684751,
        0.77905101,
        0.86493331,
        0.41139671,
        0.13997258,
    };
    tc.add_input<float>(Shape{5}, beta_values);

    const std::vector<int> mask_values{
        1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
    };
    tc.add_input<int>(Shape{3, 8}, mask_values);

    // First expected output: float values.
    const std::vector<float> reference{
        -0.04089922, 0.35108989,  0.30442458,  0.39546335,  1.15422225,  0.10419128,  -0.19301927, 0.01070970,
        0.43977541,  0.89119899,  -0.51436460, 1.99256825,  1.41077507,  0.38642293,  0.17583044,  0.03320138,
        1.16508031,  -0.24356931, 0.47440714,  -0.17844005, 0.20463173,  1.90038323,  1.14138567,  0.34504607,
        0.16403235,  -0.24976699, 0.29362509,  0.34502214,  0.41751838,  1.09390712,  0.12354189,  1.83025289,
        1.05569196,  0.34413773,  0.35469764,  -0.69760042, 0.76338542,  1.75443077,  0.44126555,  0.18181801,
        0.73277575,  0.45443264,  0.17068321,  0.36591727,  0.72869974,  -0.56090516, 0.14415455,  1.47314119,
        0.42908576,  0.73084539,  -0.22373237, 2.26550221,  0.05606699,  0.39417523,  0.35234636,  0.78569502,
        0.77521765,  -0.65131050, 0.40168875,  0.45527256,  0.38715565,  0.98521245,  2.21446753,  0.36345237,
        -0.33269632, 0.36558092,  1.36846578,  1.37523413,  0.33698002,  0.28889543,  -0.40639281, 1.01643157,
        0.59668219,  0.39197800,  1.03101778,  0.02551098,  -0.03612846, -0.01371557, 0.43444607,  0.96746695,
        0.60583955,  -0.10362893, 0.40574494,  0.38046724,  0.87445319,  -0.00880148, -0.15437943, 0.08118075,
        0.44650543,  0.85956848,  -0.27865338, 2.10837507,  0.04798460,  0.43948367,  -0.10185169, 0.19978794,
        1.32323360,  1.20525467,  0.44288942,  -0.84200430, 0.52563053,  0.69949460,  0.73987913,  0.34668452,
        0.74545687,  0.57696682,  0.22452033,  -0.27099937, 0.39649010,  0.87083614,  -0.18965788, 0.58206403,
        -0.08108193, 0.42067638,  1.05117214,  -0.34287399, 0.20424896,  0.27994895,  0.46011117,  0.70890665,
    };
    tc.add_expected_output<float>(Shape{3, 8, 5}, reference);

    // Second expected output: mask index. Each value equals the number of ones in the
    // corresponding row of the {3, 8} mask.
    const std::vector<int> reference_mask_index{
        6,
        5,
        5,
    };
    tc.add_expected_output<int>(Shape{3}, reference_mask_index);

    tc.run_with_tolerance_as_fp(1e-6);
}

// Loads "attention.onnx", feeds a single 24-value float input and compares the output with
// 32 reference values at a tolerance of 1e-7.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_attention) {
    const auto model_path = file_util::path_join(SERIALIZED_ZOO, "onnx/com.microsoft/attention.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);

    const std::vector<float> input_values{
        0.91475844, 0.91523546, 0.82536930, 0.37491974, 0.22384071, 0.05941105, 0.01902100, 0.70131350,
        0.09603709, 0.44200060, 0.53106076, 0.79464376, 0.35469049, 0.25225943, 0.25179818, 0.29592562,
        0.24836586, 0.65088797, 0.93126643, 0.67980725, 0.85708112, 0.59808528, 0.46321425, 0.19301885,
    };
    const std::vector<float> reference{
        0.07966283, 0.10783536, -0.19424979, 0.54514766, 0.07965867, 0.10783093, -0.19424866, 0.54510003,
        0.07965846, 0.10783067, -0.19424550, 0.54509139, 0.07966217, 0.10783640, -0.19424903, 0.54512268,
        0.06940663, 0.10962760, -0.19698445, 0.53492010, 0.06940675, 0.10962828, -0.19698484, 0.53492326,
        0.06940714, 0.10963022, -0.19698712, 0.53494006, 0.06940673, 0.10962812, -0.19698519, 0.53492481,
    };

    auto tc = test::TestCase(model, s_device);
    tc.add_input<float>(input_values);
    tc.add_expected_output<float>(reference);
    tc.run_with_tolerance_as_fp(1e-7);
}

// Loads "attention_qkv_hidden_sizes.onnx", feeds a single 24-value float input and compares
// the output with 64 reference values at a tolerance of 1e-6.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_attention_qkv_hidden_sizes) {
    const auto model_path =
        file_util::path_join(SERIALIZED_ZOO, "onnx/com.microsoft/attention_qkv_hidden_sizes.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);

    const std::vector<float> input_values{
        0.56477863, 0.60309958, 0.35158035, 0.03123519, 0.81918180, 0.76905495, 0.47219241, 0.72016627,
        0.59377003, 0.91380632, 0.56797302, 0.34846428, 0.83839595, 0.16394103, 0.34676281, 0.09161621,
        0.45562279, 0.23317528, 0.37197968, 0.06727808, 0.08500192, 0.84915495, 0.68266946, 0.00227691,
    };
    const std::vector<float> reference{
        -0.59370947, -0.30300471, 0.12048547, -0.09029539, 0.08041390, 0.10250041, -0.19381392, 0.55126983,
        -0.59370828, -0.30301332, 0.12049319, -0.09029691, 0.08041921, 0.10250521, -0.19381438, 0.55127531,
        -0.59370869, -0.30301058, 0.12049074, -0.09029643, 0.08041564, 0.10250199, -0.19381410, 0.55127168,
        -0.59370929, -0.30300608, 0.12048667, -0.09029562, 0.08041184, 0.10249855, -0.19381374, 0.55126774,
        -0.59681994, -0.26327702, 0.07638434, -0.06311120, 0.06671587, 0.10916986, -0.19412412, 0.51977092,
        -0.59682053, -0.26328400, 0.07638102, -0.06311222, 0.06671817, 0.10917170, -0.19412397, 0.51977223,
        -0.59682077, -0.26328647, 0.07637984, -0.06311259, 0.06671739, 0.10917108, -0.19412403, 0.51977175,
        -0.59682101, -0.26328778, 0.07637922, -0.06311278, 0.06671065, 0.10916568, -0.19412443, 0.51976782,
    };

    auto tc = test::TestCase(model, s_device);
    tc.add_input<float>(input_values);
    tc.add_expected_output<float>(reference);
    tc.run_with_tolerance_as_fp(1e-6);
}

// Loads "attention_unidirectional.onnx", feeds a single 24-value float input and checks two
// float outputs (32 output values, then 64 "present" values) at a tolerance of 1e-7.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_attention_unidirectional) {
    const auto model_path =
        file_util::path_join(SERIALIZED_ZOO, "onnx/com.microsoft/attention_unidirectional.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);

    const std::vector<float> input_values{
        0.89578921, 0.42421508, 0.35630688, 0.77461642, 0.65753633, 0.09723099, 0.62597734, 0.72117692,
        0.57636845, 0.17104276, 0.13245547, 0.59879875, 0.15624641, 0.44903454, 0.50483286, 0.92975074,
        0.36934483, 0.29919949, 0.57185954, 0.83036488, 0.08384345, 0.20378476, 0.74684393, 0.46716982,
    };
    // First expected output.
    const std::vector<float> reference_output{
        0.05604819, 0.09000472, -0.19437021, 0.52487367, 0.06211422, 0.08740954, -0.19139624, 0.52762908,
        0.06708897, 0.08992603, -0.19214047, 0.53631783, 0.06896879, 0.10248676, -0.19485690, 0.53477794,
        0.08577005, 0.12807365, -0.19762954, 0.54432857, 0.06929274, 0.10893210, -0.19599904, 0.53184807,
        0.07348281, 0.10215081, -0.19280069, 0.53552240, 0.07861833, 0.10517240, -0.19285706, 0.54126489,
    };
    // Second expected output.
    const std::vector<float> reference_present{
        -0.60427380, -0.25958878, -0.59609234, -0.24055196, -0.59613681, -0.30088067, -0.59633607, -0.33270463,
        0.06899665,  -0.09284544, 0.08059876,  -0.06146053, 0.11841078,  -0.10019838, 0.10605468,  -0.09273906,
        -0.59036821, -0.32410735, -0.60532302, -0.25127757, -0.58926487, -0.25271094, -0.58640373, -0.31730092,
        0.12509561,  -0.07968873, 0.06005794,  -0.08937149, 0.10523240,  -0.05083811, 0.14162725,  -0.07438751,
        0.05604819,  0.09000472,  0.06819826,  0.08480665,  0.07700446,  0.09494394,  0.07459175,  0.14003153,
        -0.19437021, 0.52487367,  -0.18843602, 0.53037173,  -0.19362189, 0.55360907,  -0.20299932, 0.53020388,
        0.08577005,  0.12807365,  0.05276009,  0.08972625,  0.08190014,  0.08852972,  0.09400313,  0.11423884,
        -0.19762954, 0.54432857,  -0.19435294, 0.51924801,  -0.18643703, 0.54280555,  -0.19302703, 0.55837619,
    };

    auto tc = test::TestCase(model, s_device);
    tc.add_input<float>(input_values);
    tc.add_expected_output<float>(reference_output);
    tc.add_expected_output<float>(reference_present);
    tc.run_with_tolerance_as_fp(1e-7);
}

// Loads "attention_mask_index_1.onnx", feeds a 24-value float input plus a 2-element int
// mask index, and checks two float outputs (32 output values, then 64 "present" values)
// with the default floating-point tolerance.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_attention_mask_index_1) {
    const auto model_path =
        file_util::path_join(SERIALIZED_ZOO, "onnx/com.microsoft/attention_mask_index_1.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);

    const std::vector<float> input_values{
        0.02841483, 0.47845092, 0.14633700, 0.54597300, 0.40160629, 0.55281311, 0.14931096, 0.64483738,
        0.96559167, 0.05262021, 0.12391864, 0.20093553, 0.74290562, 0.19367455, 0.19253619, 0.41593507,
        0.91188699, 0.61606920, 0.72673517, 0.86981291, 0.19963337, 0.22747350, 0.34308898, 0.57267183,
    };
    const std::vector<int> mask_index_values{
        0,
        1,
    };
    // First expected output.
    const std::vector<float> reference_output{
        0.08298690, 0.12711772, -0.19757506, 0.54029012, 0.08298548, 0.12711433, -0.19757731, 0.54031140,
        0.08298430, 0.12711799, -0.19757695, 0.54031777, 0.08298548, 0.12711433, -0.19757444, 0.54028159,
        0.05380550, 0.10459180, -0.19593412, 0.50907606, 0.05380550, 0.10459180, -0.19593412, 0.50907606,
        0.05380550, 0.10459180, -0.19593412, 0.50907606, 0.05380550, 0.10459180, -0.19593412, 0.50907606,
    };
    // Second expected output.
    const std::vector<float> reference_present{
        -0.58437425, -0.29483819, -0.59927911, -0.30336475, -0.59104657, -0.37327260, -0.59078789, -0.29863101,
        0.11751597,  -0.04114649, 0.09933343,  -0.09884726, 0.16250694,  -0.12028439, 0.09319257,  -0.05129660,
        -0.60341775, -0.25221461, -0.58933026, -0.31912822, -0.59271193, -0.25470981, -0.59399152, -0.32643768,
        0.05398282,  -0.07468132, 0.14743008,  -0.09407346, 0.10399222,  -0.06682440, 0.11632499,  -0.08986320,
        0.09104910,  0.12973849,  0.06917210,  0.11059431,  0.09356256,  0.12594685,  0.07814129,  0.14221822,
        -0.19329809, 0.53526556,  -0.19787431, 0.53673857,  -0.20045389, 0.57165766,  -0.19869246, 0.51749766,
        0.05380550,  0.10459180,  0.09169570,  0.09892380,  0.07746917,  0.08042616,  0.07953370,  0.12909687,
        -0.19593412, 0.50907606,  -0.19202785, 0.56904894,  -0.18689045, 0.54643762,  -0.19969353, 0.53976399,
    };

    auto tc = test::TestCase(model, s_device);
    tc.add_input<float>(input_values);
    tc.add_input<int>(mask_index_values);
    tc.add_expected_output<float>(reference_output);
    tc.add_expected_output<float>(reference_present);
    tc.run_with_tolerance_as_fp();
}

// Loads "attention_mask_index_2.onnx", feeds a 24-value float input plus a 4-element int
// mask index, and checks two float outputs (32 output values, then 64 "present" values)
// with the default floating-point tolerance.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_attention_mask_index_2) {
    const auto model_path =
        file_util::path_join(SERIALIZED_ZOO, "onnx/com.microsoft/attention_mask_index_2.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);

    const std::vector<float> input_values{
        0.75259578, 0.81492645, 0.46713001, 0.29483622, 0.06768602, 0.95105755, 0.32065326, 0.52417183,
        0.73136383, 0.77176476, 0.60997742, 0.64625764, 0.16311000, 0.89680773, 0.01331447, 0.42468646,
        0.58711547, 0.00345124, 0.13053808, 0.46278623, 0.13786320, 0.65182054, 0.74864876, 0.81506181,
    };
    const std::vector<int> mask_index_values{
        3,
        3,
        1,
        1,
    };
    // First expected output.
    const std::vector<float> reference_output{
        0.07524174, 0.11320241, -0.19909523, 0.54785377, 0.06825337, 0.13981669, -0.20774621, 0.53718704,
        0.07531278, 0.12957911, -0.20330518, 0.54547405, 0.07531209, 0.12958010, -0.20330583, 0.54547292,
        0.08900890, 0.11150353, -0.18931937, 0.53757656, 0.07915881, 0.10416336, -0.18914750, 0.52921104,
        0.08285815, 0.11462159, -0.19115375, 0.53077918, 0.08285838, 0.11462225, -0.19115454, 0.53077984,
    };
    // Second expected output.
    const std::vector<float> reference_present{
        -0.59630549, -0.28110915, -0.60274345, -0.36154836, -0.59437746, -0.33717164, -0.60134649, -0.29849592,
        0.11169122,  -0.09345293, 0.11103803,  -0.13096604, 0.13131849,  -0.10597084, 0.10463209,  -0.11332577,
        -0.57949269, -0.27235535, -0.58941406, -0.25372508, -0.58658379, -0.28718373, -0.59821802, -0.32433146,
        0.13244939,  -0.02865628, 0.09308393,  -0.04083736, 0.10948701,  -0.04423397, 0.13060363,  -0.12316251,
        0.07509718,  0.08392500,  0.06825337,  0.13981669,  0.08239168,  0.11931328,  0.06770951,  0.09240761,
        -0.19074154, 0.55260652,  -0.20774621, 0.53718704,  -0.19888818, 0.55371630,  -0.19559640, 0.54754448,
        0.09983939,  0.10603377,  0.07915881,  0.10416336,  0.08655046,  0.12505992,  0.07738422,  0.09509270,
        -0.18571433, 0.55095005,  -0.18914750, 0.52921104,  -0.19315663, 0.53234470,  -0.19601485, 0.56322992,
    };

    auto tc = test::TestCase(model, s_device);
    tc.add_input<float>(input_values);
    tc.add_input<int>(mask_index_values);
    tc.add_expected_output<float>(reference_output);
    tc.add_expected_output<float>(reference_present);
    tc.run_with_tolerance_as_fp();
}

// Loads "attention_mask_index_3.onnx", feeds a 24-value float input plus an 8-element 0/1 int
// mask, and checks two float outputs (32 output values, then 64 "present" values) at a
// tolerance of 1e-7.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_attention_mask_index_3) {
    const auto model_path =
        file_util::path_join(SERIALIZED_ZOO, "onnx/com.microsoft/attention_mask_index_3.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);

    const std::vector<float> input_values{
        0.33093750, 0.39181390, 0.14586255, 0.39709702, 0.98086524, 0.03891133, 0.72234219, 0.21966648,
        0.79986620, 0.97251678, 0.04131543, 0.43971965, 0.50185394, 0.11452501, 0.88111717, 0.76076663,
        0.31870860, 0.54107893, 0.91756296, 0.58112669, 0.99117357, 0.00256292, 0.58885485, 0.93481058,
    };
    const std::vector<int> mask_values{
        1,
        1,
        1,
        0,
        0,
        0,
        0,
        1,
    };
    // First expected output.
    const std::vector<float> reference_output{
        0.07551830, 0.10666487, -0.19357042, 0.53683108, 0.07551410, 0.10666656, -0.19356072, 0.53684169,
        0.07552745, 0.10666100, -0.19358172, 0.53682435, 0.07552218, 0.10666317, -0.19358677, 0.53681952,
        0.09727416, 0.13513327, -0.20121223, 0.57003713, 0.09727416, 0.13513327, -0.20121223, 0.57003713,
        0.09727416, 0.13513327, -0.20121223, 0.57003713, 0.09727416, 0.13513327, -0.20121223, 0.57003713,
    };
    // Second expected output.
    const std::vector<float> reference_present{
        -0.59174627, -0.27471560, -0.58307797, -0.25967693, -0.60766846, -0.31754097, -0.61241394, -0.26291698,
        0.09206123,  -0.05307099, 0.12491645,  -0.03853742, 0.08732655,  -0.13050151, 0.04073093,  -0.10792807,
        -0.60556883, -0.34055573, -0.60474855, -0.28785610, -0.60757709, -0.32514900, -0.58872569, -0.37967020,
        0.09779400,  -0.13136166, 0.07915612,  -0.10649752, 0.11043755,  -0.15124020, 0.16626491,  -0.11274654,
        0.07639833,  0.11762549,  0.09370039,  0.09133558,  0.05661478,  0.11096847,  0.04019671,  0.10117501,
        -0.19371650, 0.52530587,  -0.18429738, 0.55240726,  -0.20283231, 0.53265429,  -0.20036045, 0.50568837,
        0.06171235,  0.12687264,  0.05802051,  0.10266830,  0.06172965,  0.08967118,  0.09727416,  0.13513327,
        -0.20576829, 0.53365225,  -0.19832623, 0.52809310,  -0.19971462, 0.55584043,  -0.20121223, 0.57003713,
    };

    auto tc = test::TestCase(model, s_device);
    tc.add_input<float>(input_values);
    tc.add_input<int>(mask_values);
    tc.add_expected_output<float>(reference_output);
    tc.add_expected_output<float>(reference_present);
    tc.run_with_tolerance_as_fp(1e-7);
}

// Imports com.microsoft/attention_mask_index_4.onnx, feeds it one float tensor and one
// int mask tensor, and checks both model outputs against the reference values below.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_attention_mask_index_4) {
    // Model inputs (fed in this order: float data first, then the int mask).
    const std::vector<float> input_values = {
        0.23565151, 0.58627969, 0.75137484, 0.68586946, 0.62750375, 0.13284931, 0.13347220, 0.36357051,
        0.56910241, 0.48275986, 0.49440190, 0.45483324, 0.63547862, 0.97893149, 0.40630588, 0.38783622,
        0.07172249, 0.46385381, 0.99764502, 0.22219376, 0.67735291, 0.40799847, 0.74337566, 0.87263006,
    };
    const std::vector<int> mask_values = {
        1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1,
    };

    // Reference results for the first and second model output respectively.
    const std::vector<float> expected_output = {
        0.07771622, 0.10724538, -0.19453585, 0.54342043, 0.07459468, 0.10934003, -0.19561143, 0.53936625,
        0.07927690, 0.10619678, -0.19399606, 0.54543519, 0.07459468, 0.10934003, -0.19561143, 0.53936625,
        0.05485561, 0.11278091, -0.20117569, 0.52096349, 0.06629646, 0.10195158, -0.19900991, 0.54654449,
        0.06491723, 0.10292297, -0.19678673, 0.53451663, 0.06549793, 0.11126325, -0.19989857, 0.53717279,
    };
    const std::vector<float> expected_present = {
        -0.59188855, -0.34495637, -0.59508181, -0.25013468, -0.59176934, -0.33229247, -0.59576762, -0.29731843,
        0.14217430,  -0.10403840, 0.08584045,  -0.06193545, 0.12358667,  -0.08588549, 0.10515238,  -0.08629489,
        -0.59092808, -0.28260738, -0.60047609, -0.30411413, -0.61210287, -0.28645760, -0.59391296, -0.34649473,
        0.12789863,  -0.08159252, 0.08122411,  -0.08866425, 0.06395009,  -0.12896645, 0.14855847,  -0.11978809,
        0.08783118,  0.12152332,  0.07067389,  0.09078297,  0.08385989,  0.13306075,  0.07459468,  0.10934003,
        -0.19849420, 0.55928540,  -0.18948570, 0.53154731,  -0.19960676, 0.54237455,  -0.19561143, 0.53936625,
        0.08509844,  0.08314656,  0.06388859,  0.12990499,  0.04582624,  0.09566365,  0.08674107,  0.10823163,
        -0.18808734, 0.56137776,  -0.20168513, 0.51830697,  -0.20066255, 0.52363914,  -0.19737384, 0.56921995,
    };

    const auto model_path = file_util::path_join(SERIALIZED_ZOO, "onnx/com.microsoft/attention_mask_index_4.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);
    auto test_case = test::TestCase(model, s_device);

    test_case.add_input<float>(input_values);
    test_case.add_input<int>(mask_values);
    test_case.add_expected_output<float>(expected_output);
    test_case.add_expected_output<float>(expected_present);
    test_case.run_with_tolerance_as_fp(1e-7);
}

// Imports com.microsoft/attention_past.onnx, feeds it a float tensor, an int mask and a
// float "past" tensor, and checks both model outputs against the reference values below.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_attention_past) {
    // Model inputs (fed in this order: float data, int mask, float past).
    const std::vector<float> input_values = {
        0.82966000, 0.77751911, 0.08977074, 0.06076468, 0.40659550, 0.19995944, 0.55544919, 0.83971608,
        0.86254036, 0.30894691, 0.80156928, 0.83092463, 0.14506543, 0.32196075, 0.42209163, 0.24465553,
        0.93944097, 0.73528159, 0.23347616, 0.60544974, 0.93329269, 0.67604774, 0.56349903, 0.26199624,
    };
    const std::vector<int> mask_values = {
        1,
        1,
        1,
        0,
        0,
        0,
        1,
        0,
        1,
        0,
        1,
        0,
        1,
        0,
        1,
        1,
        1,
        1,
    };
    const std::vector<float> past_values = {
        0.92467678, 0.79873562, 0.00939191, 0.34891853, 0.35521412, 0.21872006, 0.89974332, 0.74132687, 0.73566031,
        0.75168055, 0.06773245, 0.85702997, 0.76256698, 0.51739877, 0.91567177, 0.66617578, 0.88056499, 0.08436447,
        0.54744655, 0.25466520, 0.08500137, 0.19271941, 0.86525357, 0.21717627, 0.97158766, 0.42288730, 0.09890039,
        0.01148765, 0.97024685, 0.19697112, 0.67671591, 0.67960924, 0.46656516, 0.30850092, 0.73536104, 0.73938161,
        0.91650903, 0.57628596, 0.51164514, 0.11695814, 0.79792547, 0.97192264, 0.29246020, 0.41030061, 0.19014873,
        0.90233624, 0.84986305, 0.26141909, 0.84528726, 0.81416380, 0.00429944, 0.31476986, 0.00440918, 0.77413058,
        0.13409913, 0.20965169, 0.61764991, 0.55266041, 0.56107825, 0.42051074, 0.16804738, 0.80362344, 0.52392679,
        0.27550557, 0.66738850, 0.39348483, 0.31801429, 0.30325863, 0.37068403, 0.92767614, 0.60799408, 0.01458820,
        0.24194679, 0.59596598, 0.81762302, 0.38094005, 0.16618672, 0.92488551, 0.84298438, 0.21752745,
    };

    // Reference results for the first and second model output respectively.
    const std::vector<float> expected_output = {
        0.26186451, 0.45950246, -0.04001215, 0.47680017, 0.26333901, 0.46158865, -0.04006424, 0.47588652,
        0.26875457, 0.47031689, -0.03951600, 0.47674999, 0.26851410, 0.46987134, -0.03919901, 0.47629333,
        0.18083976, 0.16579385, -0.05161894, 0.63075018, 0.18228555, 0.16642828, -0.04873618, 0.63316816,
        0.18362364, 0.16702136, -0.05045432, 0.63178891, 0.18000112, 0.16541445, -0.05139139, 0.63105792,
    };
    const std::vector<float> expected_present = {
        0.92467678,  0.79873562,  0.00939191,  0.34891853,  0.35521412,  0.21872006,  0.89974332,  0.74132687,
        0.73566031,  0.75168055,  -0.59527576, -0.23625080, -0.58657664, -0.29827437, -0.59528387, -0.33578828,
        -0.59068960, -0.34870598, 0.06773245,  0.85702997,  0.76256698,  0.51739877,  0.91567177,  0.66617578,
        0.88056499,  0.08436447,  0.54744655,  0.25466520,  0.08536442,  -0.06134639, 0.11295843,  -0.04818217,
        0.14562836,  -0.12305059, 0.15695867,  -0.11161390, 0.08500137,  0.19271941,  0.86525357,  0.21717627,
        0.97158766,  0.42288730,  0.09890039,  0.01148765,  0.97024685,  0.19697112,  -0.59141791, -0.31600696,
        -0.58647990, -0.34302223, -0.59306550, -0.36427227, -0.59695083, -0.26431620, 0.67671591,  0.67960924,
        0.46656516,  0.30850092,  0.73536104,  0.73938161,  0.91650903,  0.57628596,  0.51164514,  0.11695814,
        0.11255538,  -0.07302766, 0.16620418,  -0.09871224, 0.15272795,  -0.12076923, 0.08827571,  -0.07442430,
        0.79792547,  0.97192264,  0.29246020,  0.41030061,  0.19014873,  0.90233624,  0.84986305,  0.26141909,
        0.84528726,  0.81416380,  0.07014155,  0.07749540,  0.08745074,  0.13131952,  0.08430066,  0.09709007,
        0.09247591,  0.11065811,  0.00429944,  0.31476986,  0.00440918,  0.77413058,  0.13409913,  0.20965169,
        0.61764991,  0.55266041,  0.56107825,  0.42051074,  -0.18658412, 0.53568852,  -0.19482780, 0.53271860,
        -0.19558203, 0.57155901,  -0.19633618, 0.57260245,  0.16804738,  0.80362344,  0.52392679,  0.27550557,
        0.66738850,  0.39348483,  0.31801429,  0.30325863,  0.37068403,  0.92767614,  0.08172131,  0.13249113,
        0.09947956,  0.10781212,  0.08890627,  0.12280971,  0.06911418,  0.09499176,  0.60799408,  0.01458820,
        0.24194679,  0.59596598,  0.81762302,  0.38094005,  0.16618672,  0.92488551,  0.84298438,  0.21752745,
        -0.19839945, 0.53462923,  -0.19349247, 0.57778782,  -0.20039621, 0.56689924,  -0.19190890, 0.53286803,
    };

    const auto model_path = file_util::path_join(SERIALIZED_ZOO, "onnx/com.microsoft/attention_past.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);
    auto test_case = test::TestCase(model, s_device);

    test_case.add_input<float>(input_values);
    test_case.add_input<int>(mask_values);
    test_case.add_input<float>(past_values);
    test_case.add_expected_output<float>(expected_output);
    test_case.add_expected_output<float>(expected_present);
    test_case.run_with_tolerance_as_fp(1e-6);
}

// Imports com.microsoft/attention_extra_add.onnx, feeds it a float tensor, an int mask and a
// float "extra_add" tensor, and checks both model outputs against the reference values below.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_attention_extra_add) {
    // Model inputs (fed in this order: float data, int mask, float extra_add).
    const std::vector<float> input_values = {
        0.14930259, 0.11199699, 0.81292826, 0.08368169, 0.05704883, 0.41276145, 0.38760167, 0.00146112,
        0.14275745, 0.54254925, 0.07962929, 0.31023681, 0.09597706, 0.60583973, 0.90233743, 0.33360451,
        0.18193199, 0.19159532, 0.07869831, 0.86026299, 0.20683478, 0.40150928, 0.93124926, 0.31805834,
    };
    const std::vector<int> mask_values = {
        0,
        0,
        1,
        0,
        1,
        1,
        1,
        0,
    };
    const std::vector<float> extra_add_values = {
        0.73230380, 0.61824518, 0.19738488, 0.57034588, 0.22331032, 0.53262889, 0.60098642, 0.72943515,
        0.09009175, 0.81116527, 0.47240964, 0.49679127, 0.41110733, 0.29418564, 0.93818313, 0.64175284,
        0.06807775, 0.66733366, 0.78848422, 0.48788327, 0.38806340, 0.14002480, 0.72263688, 0.22772972,
        0.24000823, 0.75820386, 0.64254439, 0.19385594, 0.95595860, 0.59840417, 0.93769604, 0.62474734,
        0.36690548, 0.76047903, 0.62352085, 0.58574778, 0.64251810, 0.78072041, 0.43344691, 0.75383639,
        0.73950553, 0.92625278, 0.05066428, 0.08448382, 0.25980917, 0.50312829, 0.97800279, 0.05422170,
        0.05171391, 0.82828254, 0.42234898, 0.95752198, 0.96325767, 0.97909677, 0.35578200, 0.48091716,
        0.03637243, 0.91552693, 0.43403026, 0.94275808, 0.51182085, 0.86773109, 0.38459453, 0.87822068,
    };

    // Reference results for the first and second model output respectively.
    const std::vector<float> expected_output = {
        0.06090815, 0.12919067, -0.19883196, 0.50295448, 0.06090815, 0.12919067, -0.19883196, 0.50295448,
        0.06090815, 0.12919067, -0.19883196, 0.50295448, 0.06090815, 0.12919067, -0.19883196, 0.50295448,
        0.08714182, 0.12259886, -0.19516067, 0.54010558, 0.08671370, 0.12369543, -0.19658084, 0.54502594,
        0.08458151, 0.12488046, -0.19519810, 0.53906947, 0.09063499, 0.12088943, -0.19583938, 0.54266596,
    };
    const std::vector<float> expected_present = {
        -0.59800303, -0.35666457, -0.59420627, -0.31881350, -0.59887993, -0.27025288, -0.60216135, -0.27772796,
        0.11659990,  -0.11224300, 0.09693416,  -0.07304113, 0.06023501,  -0.05941332, 0.06434284,  -0.07978789,
        -0.59005713, -0.37009716, -0.59542215, -0.27914333, -0.57998544, -0.29826957, -0.58625919, -0.28872511,
        0.15994480,  -0.11288825, 0.07906821,  -0.05991337, 0.14479136,  -0.04415035, 0.13493451,  -0.06541853,
        0.07513385,  0.14411135,  0.07505661,  0.14532046,  0.06090815,  0.12919067,  0.05788904,  0.12018456,
        -0.20586906, 0.53715372,  -0.20203318, 0.52092510,  -0.19883196, 0.50295448,  -0.19937295, 0.51055026,
        0.09417956,  0.12943678,  0.06923291,  0.12574309,  0.10221909,  0.11366953,  0.09235901,  0.09584601,
        -0.20036517, 0.56818324,  -0.19709785, 0.51547027,  -0.18871340, 0.55736589,  -0.18826833, 0.55965197,
    };

    const auto model_path = file_util::path_join(SERIALIZED_ZOO, "onnx/com.microsoft/attention_extra_add.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);
    auto test_case = test::TestCase(model, s_device);

    test_case.add_input<float>(input_values);
    test_case.add_input<int>(mask_values);
    test_case.add_input<float>(extra_add_values);
    test_case.add_expected_output<float>(expected_output);
    test_case.add_expected_output<float>(expected_present);
    test_case.run_with_tolerance_as_fp(1e-7);
}

// Imports com.microsoft/attention_dynamic_shapes.onnx and runs it with every input and
// expected output given an explicit Shape. Five inputs are supplied (data, weights, bias,
// mask, past) and both model outputs are compared against the reference values below.
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_attention_dynamic_shapes) {
    // Fed as Shape{2, 4, 3}.
    const std::vector<float> input_values = {
        0.42226878, 0.50984067, 0.80440795, 0.68040705, 0.93614250, 0.45104721, 0.71767306, 0.48596525,
        0.70076728, 0.04500086, 0.28930107, 0.77435863, 0.19392140, 0.90290719, 0.91955870, 0.58811885,
        0.76795286, 0.62884814, 0.23377730, 0.49212688, 0.87256873, 0.11944817, 0.57715887, 0.91886938,
    };
    // Fed as Shape{3, 12}.
    const std::vector<float> weight_values = {
        0.99377930, 0.22733542, 0.43217131, 0.60717988, 0.97224706, 0.70020503, 0.92439449, 0.41512674, 0.47728160,
        0.40306625, 0.72619593, 0.37954643, 0.36950976, 0.84305370, 0.61671126, 0.22251014, 0.73839295, 0.73471880,
        0.37428924, 0.80240524, 0.23120961, 0.06072779, 0.92840081, 0.71558088, 0.08719950, 0.51666921, 0.53768843,
        0.48113129, 0.46389169, 0.01036468, 0.37341005, 0.67195475, 0.53599644, 0.41795707, 0.58081782, 0.97939289,
    };
    // Fed as Shape{12}.
    const std::vector<float> bias_values = {
        0.77122736,
        0.75600564,
        0.86177206,
        0.69982684,
        0.74719858,
        0.78054035,
        0.80007398,
        0.74902135,
        0.81258053,
        0.01575289,
        0.08463049,
        0.39671996,
    };
    // Fed as Shape{2, 9}.
    const std::vector<int> mask_values = {
        0,
        1,
        0,
        0,
        0,
        1,
        0,
        0,
        1,
        1,
        0,
        0,
        1,
        1,
        0,
        0,
        0,
        0,
    };
    // Fed as Shape{2, 2, 2, 5, 2}.
    const std::vector<float> past_values = {
        0.27759778, 0.18458818, 0.63114458, 0.09953160, 0.59739488, 0.63917851, 0.18828323, 0.65625650, 0.84574437,
        0.91846281, 0.55102497, 0.27506110, 0.06816208, 0.82616585, 0.85912132, 0.88682729, 0.14730524, 0.61618829,
        0.89891797, 0.27753425, 0.57438278, 0.33753166, 0.88768929, 0.35533753, 0.30193496, 0.81678063, 0.26569194,
        0.62769043, 0.61990744, 0.59077013, 0.11058200, 0.97370809, 0.81339806, 0.57207322, 0.80417949, 0.54185718,
        0.80831683, 0.29390740, 0.29051417, 0.51964313, 0.04341308, 0.05925354, 0.82397246, 0.55753845, 0.61247689,
        0.98571628, 0.07566493, 0.37537411, 0.42080343, 0.21715857, 0.57869565, 0.55962265, 0.82500041, 0.60776925,
        0.19367239, 0.88382334, 0.20328504, 0.58192456, 0.94542676, 0.98562658, 0.64355153, 0.69856495, 0.30377558,
        0.02857198, 0.96969068, 0.48450547, 0.98341352, 0.03546083, 0.84963584, 0.94460547, 0.90907097, 0.22525074,
        0.12530145, 0.52223104, 0.09549426, 0.93127102, 0.93429947, 0.01428344, 0.74249738, 0.22606593,
    };

    // Expected first output, compared as Shape{2, 4, 4}.
    const std::vector<float> expected_output = {
        1.47439122, 0.50951630, 1.17974961, 1.58501005, 1.49403512, 0.51560062, 1.18972027, 1.59668207,
        1.48384988, 0.51248586, 1.18596375, 1.59219086, 1.44181466, 0.50219649, 1.15537691, 1.55348074,
        0.83429223, 0.59521818, 0.87688094, 0.13611843, 0.82936716, 0.61004817, 0.87633312, 0.13887596,
        0.83155584, 0.59382534, 0.87496555, 0.14041223, 0.83309680, 0.58982348, 0.87517864, 0.13930768,
    };
    // Expected second output, compared as Shape{2, 2, 2, 9, 2}.
    const std::vector<float> expected_present = {
        0.27759778, 0.18458818, 0.63114458, 0.09953160, 0.59739488, 0.63917851, 0.18828323, 0.65625650, 0.84574437,
        0.91846281, 1.90736914, 1.45914197, 2.30920029, 1.94944119, 2.12886763, 1.64736962, 1.36378694, 1.03263116,
        0.55102497, 0.27506110, 0.06816208, 0.82616585, 0.85912132, 0.88682729, 0.14730524, 0.61618829, 0.89891797,
        0.27753425, 1.68161881, 1.87394094, 1.94785213, 2.08572555, 1.90705216, 1.90777159, 1.23910809, 1.52017307,
        0.57438278, 0.33753166, 0.88768929, 0.35533753, 0.30193496, 0.81678063, 0.26569194, 0.62769043, 0.61990744,
        0.59077013, 2.02901411, 1.58923888, 2.17776394, 1.76309133, 1.74264824, 1.31485105, 1.71575761, 1.29775190,
        0.11058200, 0.97370809, 0.81339806, 0.57207322, 0.80417949, 0.54185718, 0.80831683, 0.29390740, 0.29051417,
        0.51964313, 1.66065478, 2.17192268, 1.86598253, 2.03193212, 1.52620018, 1.82728052, 1.46963060, 1.87916136,
        0.04341308, 0.05925354, 0.82397246, 0.55753845, 0.61247689, 0.98571628, 0.07566493, 0.37537411, 0.42080343,
        0.21715857, 1.56316149, 0.55312467, 1.59553123, 0.53537023, 1.64308119, 0.62742490, 1.31600118, 0.37510848,
        0.57869565, 0.55962265, 0.82500041, 0.60776925, 0.19367239, 0.88382334, 0.20328504, 0.58192456, 0.94542676,
        0.98562658, 1.33183134, 1.70965421, 1.70983100, 1.76660407, 1.46399045, 1.70318413, 0.83565855, 1.37921953,
        0.64355153, 0.69856495, 0.30377558, 0.02857198, 0.96969068, 0.48450547, 0.98341352, 0.03546083, 0.84963584,
        0.94460547, 1.60677671, 0.53308368, 1.60789728, 0.56227136, 1.50563633, 0.50456268, 1.49554634, 0.48299593,
        0.90907097, 0.22525074, 0.12530145, 0.52223104, 0.09549426, 0.93127102, 0.93429947, 0.01428344, 0.74249738,
        0.22606593, 1.59781134, 2.01703453, 1.58993423, 1.78536010, 1.21809304, 1.69219351, 1.24090374, 1.75499403,
    };

    const auto model_path = file_util::path_join(SERIALIZED_ZOO, "onnx/com.microsoft/attention_dynamic_shapes.onnx");
    const auto model = onnx_import::import_onnx_model(model_path);
    auto test_case = test::TestCase(model, s_device);

    test_case.add_input<float>(Shape{2, 4, 3}, input_values);
    test_case.add_input<float>(Shape{3, 12}, weight_values);
    test_case.add_input<float>(Shape{12}, bias_values);
    test_case.add_input<int>(Shape{2, 9}, mask_values);
    test_case.add_input<float>(Shape{2, 2, 2, 5, 2}, past_values);
    test_case.add_expected_output<float>(Shape{2, 4, 4}, expected_output);
    test_case.add_expected_output<float>(Shape{2, 2, 2, 9, 2}, expected_present);
    test_case.run_with_tolerance_as_fp(1e-6);
}
